Fix nan leak from masked positions in compute_approx_kl #1635
Merged
erictang000 merged 2 commits on May 8, 2026
Conversation
Masking via `kld * loss_mask` propagates `nan` from masked positions because IEEE 754 defines `0 * nan = nan`, poisoning the downstream `masked_mean` and any metric (e.g. `policy_kl`, `final_loss`) that consumes the KL scalar. Switch to `masked_fill` so masked positions are forced to 0.0 regardless of the input value there. Autograd is unaffected.

Add a parametrized regression test covering all four estimator types (`k1`, `k2`, `k3`, `abs`) that injects `nan` at a masked position and asserts the output and downstream `masked_mean` stay finite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
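A minimal sketch of the failure mode and the `masked_fill` fix, assuming a simplified k1-style `compute_approx_kl`; the real function's signature and estimator dispatch may differ:

```python
import torch

def compute_approx_kl(log_probs, ref_log_probs, loss_mask):
    # Simplified k1-style estimator; a stand-in for the real function.
    kld = log_probs - ref_log_probs

    # Buggy version: IEEE 754 gives 0 * nan = nan, so a nan stored at a
    # masked position survives the multiply and poisons masked_mean.
    #   kld = kld * loss_mask

    # Fixed version: masked_fill writes exactly 0.0 at masked positions,
    # regardless of whatever value (nan/inf) was there.
    return kld.masked_fill(~loss_mask.bool(), 0.0)

log_probs     = torch.tensor([-0.1, float("nan"), -0.3])
ref_log_probs = torch.tensor([-0.2, -0.2, -0.2])
loss_mask     = torch.tensor([1.0, 0.0, 1.0])  # nan sits at the masked slot

kld = compute_approx_kl(log_probs, ref_log_probs, loss_mask)
assert torch.isfinite(kld).all()  # no nan leaks into downstream metrics
```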
Contributor
Code Review
This pull request updates the KL divergence computation to prevent `nan` leakage from masked positions by replacing direct multiplication with `masked_fill`. A corresponding unit test was added to verify the fix. The reviewer recommended using `torch.where` instead of `masked_fill` to preserve soft (non-binary) mask scaling while still addressing the `nan` leakage.
Switch the mask-sanitization step from `masked_fill(~mask.bool(), 0.0)` to `torch.where(mask.bool(), kld * mask, 0.0)` so non-binary mask values still scale the kept positions multiplicatively, while masked (`mask == 0`) positions are still forced to 0.0 so non-finite inputs there cannot leak.

Combine the two prior regression test cases into one parametrized test (`test_compute_approx_kl_applies_loss_mask`) that exercises both invariants in one shot: a soft mask `{1.0, 0.5, 0.25, 0.0}` with `nan` injected at the masked position, asserting kept-position scaling and masked-position zeroing for all four estimator types.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
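A sketch of the merged approach and the parametrized regression test the commit describes; the estimator formulas and function signature here are assumptions, and only the `torch.where` masking step mirrors the PR:

```python
import pytest
import torch

def compute_approx_kl(log_probs, ref_log_probs, loss_mask, kl_estimator="k1"):
    # Hypothetical stand-in for the repo's function, with common forms
    # of the four estimators named in the commit message.
    log_ratio = log_probs - ref_log_probs
    if kl_estimator == "k1":
        kld = log_ratio
    elif kl_estimator == "k2":
        kld = log_ratio.square() / 2
    elif kl_estimator == "k3":
        kld = torch.expm1(-log_ratio) + log_ratio
    elif kl_estimator == "abs":
        kld = log_ratio.abs()
    # The fix from this PR: soft masks (0 < mask < 1) still scale kept
    # positions, while mask==0 positions are hard-zeroed so non-finite
    # values there cannot leak into the output.
    return torch.where(loss_mask.bool(), kld * loss_mask, 0.0)

@pytest.mark.parametrize("kl_estimator", ["k1", "k2", "k3", "abs"])
def test_compute_approx_kl_applies_loss_mask(kl_estimator):
    log_probs = torch.tensor([[-0.1, -0.2, -0.3, float("nan")]])
    ref_log_probs = torch.zeros_like(log_probs)
    loss_mask = torch.tensor([[1.0, 0.5, 0.25, 0.0]])  # soft mask, nan at masked slot

    kld = compute_approx_kl(log_probs, ref_log_probs, loss_mask, kl_estimator)

    # Masked position is forced to exactly 0.0, so nothing non-finite leaks.
    assert torch.isfinite(kld).all()
    assert kld[0, 3].item() == 0.0

    # Kept positions are scaled multiplicatively by the soft mask values.
    unmasked = compute_approx_kl(log_probs, ref_log_probs,
                                 torch.ones_like(loss_mask), kl_estimator)
    assert torch.allclose(kld[0, :3], unmasked[0, :3] * loss_mask[0, :3])
```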
erictang000 approved these changes on May 8, 2026
Collaborator
erictang000 left a comment
nice, this lgtm, thanks!
Replace `kld * loss_mask` with `torch.where(loss_mask.bool(), kld * loss_mask, 0.0)` in `compute_approx_kl`, so `nan` at masked positions (where `0 * nan = nan`) can no longer leak through into `policy_kl`/`final_loss`.

Closes #1633